# -*- coding:utf-8 -*-

import tensorflow as tf
from config.glob.global_pool import global_pool

"""
验证方法
"""


def get_accuracy(y_true, y_pred):
    """
    Build the accuracy tensor selected by the global configuration.

    When ``global_pool.config.accuracy`` equals ``'default'``, the result is
    the mean of the element-wise matches between the argmax (along axis 1)
    of the predictions and the argmax (along axis 1) of the targets.
    For any other configured value a constant ``0.0`` (float32) is returned.

    :param y_true: target tensor; argmax is taken along axis 1
        (presumably one-hot labels of shape (batch, classes) -- confirm
        against callers)
    :param y_pred: prediction tensor; argmax is taken along axis 1
    :return: scalar tensor holding the accuracy, or a constant 0.0
    """
    if global_pool.config.accuracy != 'default':
        # No other accuracy mode is handled here: fall back to constant zero.
        return tf.constant(0.0, dtype=tf.float32)

    predicted_labels = tf.argmax(y_pred, 1)
    expected_labels = tf.argmax(y_true, 1)
    # Boolean tensor: True where the predicted class matches the target class.
    hits = tf.equal(predicted_labels, expected_labels)
    # Casting the matches to float lets their mean be the fraction correct.
    return tf.reduce_mean(tf.cast(hits, 'float'))


